import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.wrappers import FlattenObservation
import sys, os


class FlattenDictObservation(FlattenObservation):
    """``FlattenObservation`` for ``spaces.Dict`` observations that records the factor layout.

    Each entry of the wrapped environment's ``Dict`` observation space is
    treated as one factor. After construction:

    * ``factor_spaces`` lists the entry spaces in ``Dict`` iteration order;
    * ``breakpoints`` is an integer array of cumulative flattened widths, so
      factor ``k`` spans ``breakpoints[k]:breakpoints[k + 1]`` of the flattened
      observation (assuming ``spaces.flatten`` concatenates the entries in the
      same ``Dict`` iteration order).
    """

    def __init__(self, env: gym.Env):
        """Flattens the observations of an environment.

        Args:
            env: The environment to apply the wrapper to. Its observation space
                must be a ``spaces.Dict`` whose entries are 1-D ``Box`` or
                ``MultiDiscrete`` spaces.

        Raises:
            ValueError: If a ``Box`` entry is not one-dimensional.
            NotImplementedError: If an entry is neither ``Box`` nor ``MultiDiscrete``.
        """
        # TODO: check if goal_based
        self.goal_based = False

        self.dict_obs_space = env.observation_space
        self.num_factors = len(env.observation_space.spaces)

        # get state to factor mapping: cumulative flattened width per entry
        breakpoints = [0]
        self.factor_spaces = []
        for obs_key, obs_space in env.observation_space.spaces.items():
            if isinstance(obs_space, spaces.Box):
                # Raise instead of assert so the check survives ``python -O``.
                if len(obs_space.shape) != 1:
                    raise ValueError(
                        f"Box entry {obs_key!r} must be 1-D, got shape {obs_space.shape}"
                    )
                width = int(obs_space.shape[0])
            elif isinstance(obs_space, spaces.MultiDiscrete):
                # Flattening a MultiDiscrete one-hot encodes each component,
                # so its flattened width is sum(nvec).
                width = int(np.sum(obs_space.nvec))
            else:
                raise NotImplementedError(
                    f"Unsupported space type {type(obs_space).__name__} "
                    f"for observation entry {obs_key!r}"
                )
            breakpoints.append(breakpoints[-1] + width)
            self.factor_spaces.append(obs_space)
        self.breakpoints = np.array(breakpoints)

        super().__init__(env)

    def __getattr__(self, name: str):
        """Returns an attribute with ``name``, unless ``name`` starts with an underscore.

        When the wrapped env is itself a ``gym.Wrapper``, the lookup goes
        straight to ``self.env.unwrapped``, bypassing any intermediate wrappers.

        Args:
            name: The variable name

        Returns:
            The value of the variable in the wrapper stack

        Raises:
            AttributeError: For ``_np_random``, any other private name, ``env``
                before it has been assigned, or a name the target env lacks.

        Warnings:
            This feature is deprecated and removed in v1.0 and replaced with
            ``env.get_wrapper_attr(name)``.
        """
        if name == "_np_random":
            raise AttributeError(
                "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
            )
        elif name.startswith("_"):
            raise AttributeError(f"accessing private attribute '{name}' is prohibited")
        if name == "env":
            # ``env`` is only missing before ``gym.Wrapper.__init__`` assigns it.
            # Without this guard, the ``self.env`` lookups below would re-enter
            # __getattr__("env") and recurse until RecursionError.
            raise AttributeError("'env' is not set; the wrapper is not initialised yet")
        if isinstance(self.env, gym.Wrapper):
            return getattr(self.env.unwrapped, name)
        else:
            return getattr(self.env, name)

class FlattenFactorObservation(FlattenObservation):
    """``FlattenObservation`` for ac_infer environments with a factored state.

    State factors are the names in ``env.all_names`` that are not in
    ac_infer's ``non_state_factors``. ``factor_spaces`` holds one ``Box`` per
    state factor. ``breakpoints`` holds cumulative offsets built from
    ``env.object_sizes``, in ``all_names`` order, which is the same order that
    ``observation`` concatenates them.
    """

    # Import-time side effect: make ac_infer's internal imports resolvable, and
    # bring ``Environment`` into class scope for the ``__init__`` annotation.
    if os.path.join(sys.path[0], "Causal", "ac_infer") not in sys.path:
        sys.path.append(os.path.join(sys.path[0], "Causal", "ac_infer"))
    from Causal.ac_infer.Environment.environment import Environment

    def __init__(self, env: Environment):
        """Flattens the observations of an environment.
        For environments from ac_infer

        Args:
            env: The environment to apply the wrapper to. It must expose
                ``all_names``, ``num_objects``, ``object_range`` and
                ``object_sizes``, all of which are read here.
        """
        # TODO: check if goal_based
        ac_infer_path = os.path.join(sys.path[0], "Causal", "ac_infer")
        if ac_infer_path not in sys.path:
            sys.path.append(ac_infer_path)
        from Causal.ac_infer.Environment.environment import strip_instance, non_state_factors

        self.goal_based = False
        self.non_state_factors = non_state_factors
        self.all_names = env.all_names

        # NOTE(review): assumes num_objects counts exactly three non-state
        # entries (Action, Reward, Done) — confirm against non_state_factors.
        self.num_factors = env.num_objects - 3  # Action, Reward, Done

        # get state to factor mapping
        self.breakpoints = [0]
        self.factor_spaces = []
        for name in env.all_names:
            if name in non_state_factors:
                continue
            object_type = strip_instance(name)
            object_range = env.object_range[object_type]
            # Index [0] is the lower bound and [1] the upper bound (presumably
            # a (low, high) pair — confirm). Using [0] for both would make a
            # degenerate box with low == high.
            self.factor_spaces.append(spaces.Box(low=object_range[0], high=object_range[1]))
            self.breakpoints.append(self.breakpoints[-1] + env.object_sizes[object_type])
        self.breakpoints = np.array(self.breakpoints)

        super().__init__(env)

    def __getattr__(self, name: str):
        """Returns an attribute with ``name``, unless ``name`` starts with an underscore.

        When the wrapped env is itself a ``gym.Wrapper``, the lookup goes
        straight to ``self.env.unwrapped``, bypassing any intermediate wrappers.

        Args:
            name: The variable name

        Returns:
            The value of the variable in the wrapper stack

        Raises:
            AttributeError: For ``_np_random``, any other private name, ``env``
                before it has been assigned, or a name the target env lacks.

        Warnings:
            This feature is deprecated and removed in v1.0 and replaced with
            ``env.get_wrapper_attr(name)``.
        """
        if name == "_np_random":
            raise AttributeError(
                "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
            )
        elif name.startswith("_"):
            raise AttributeError(f"accessing private attribute '{name}' is prohibited")
        if name == "env":
            # ``env`` is only missing before ``gym.Wrapper.__init__`` assigns it.
            # Without this guard, the ``self.env`` lookups below would re-enter
            # __getattr__("env") and recurse until RecursionError.
            raise AttributeError("'env' is not set; the wrapper is not initialised yet")
        if isinstance(self.env, gym.Wrapper):
            return getattr(self.env.unwrapped, name)
        else:
            return getattr(self.env, name)

    def observation(self, observation):
        """Flattens an observation.

        Concatenates ``observation["factored_state"][name]`` for every name in
        ``self.all_names`` that is not in ``self.non_state_factors``, in
        ``all_names`` order. The result is then flattened against the wrapped
        environment's observation space.

        Args:
            observation: The observation to flatten. It must contain a
                ``"factored_state"`` mapping from factor name to array.

        Returns:
            The flattened observation. Goal-based splitting (achieved/desired
            goal) is not implemented; ``self.goal_based`` is always False here.
        """
        state_parts = [
            observation["factored_state"][name]
            for name in self.all_names
            if name not in self.non_state_factors
        ]
        return spaces.flatten(self.env.observation_space, np.concatenate(state_parts))
